from tqdm import tqdm
from agent import Agents
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

class Runner:
    """Train and periodically evaluate ``Agents`` on a vectorised environment.

    The runner rolls out episodes with ``run_episode``, feeds the collected
    per-step records to ``self.agents.learn`` and, every
    ``args.evaluate_rate`` training iterations, stores the loss / return
    curves (plot and ``.npy`` files) under ``args.save_dir``.
    """

    def __init__(self, args, env):
        """Store the configuration, build the agents and create the save dir.

        Args:
            args: configuration object; this class reads ``max_episode_len``,
                ``save_dir``, ``eps``, ``time_steps``, ``evaluate_rate``,
                ``evaluate_episodes``, ``vec_env``, ``n_agents``,
                ``action_shape``, ``message_shape`` and ``scenario_name``.
            env: environment exposing ``reset()`` and ``step(actions)``.
        """
        self.args = args
        self.episode_limit = args.max_episode_len
        self.env = env
        self.agents = Agents(self.args)
        self.save_path = self.args.save_dir
        self.eps = self.args.eps
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.save_path, exist_ok=True)

    def run(self):
        """Run ``args.time_steps`` training iterations with periodic evaluation.

        Each iteration rolls out one training episode and performs one
        ``learn`` call.  Whenever ``time_step`` is a positive multiple of
        ``args.evaluate_rate`` the averaged loss and the evaluation return
        are appended to the curves, which are then written to disk.
        """
        losses, returns = [], []
        loss = 0
        for time_step in tqdm(range(self.args.time_steps)):
            episode, _ = self.run_episode(evaluate=False)
            loss0 = self.agents.learn(episode)
            loss += np.sum(loss0.detach().cpu().numpy())

            if time_step > 0 and time_step % self.args.evaluate_rate == 0:
                returns.append(self._evaluate())
                losses.append(loss / self.args.evaluate_rate)
                loss = 0
                self._save_progress(losses, returns)

    def _evaluate(self):
        """Return the mean reward over ``args.evaluate_episodes`` evaluation runs."""
        reward = 0
        with torch.no_grad():
            for _ in range(self.args.evaluate_episodes):
                reward += self.run_episode(evaluate=True)[1]
        return reward / self.args.evaluate_episodes

    def _save_progress(self, losses, returns):
        """Write the loss/return plot and the raw curves into ``self.save_path``."""
        fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
        try:
            axs[0].plot(losses)
            axs[0].set_title('Loss')

            # Returns subplot; it is the bottom axes, so it carries the x label.
            axs[1].plot(returns)
            axs[1].set_title('Returns')
            axs[1].set_xlabel('Epoch({} updates)'.format(self.args.evaluate_rate))

            fig.savefig(os.path.join(self.save_path, 'plt.png'), format='png')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

        # np.save appends the ".npy" suffix itself.
        np.save(os.path.join(self.save_path, 'returns'), returns)
        np.save(os.path.join(self.save_path, 'loss'), losses)

    def run_episode(self, evaluate=False):
        """Roll out one episode in every environment of the vectorised env.

        Args:
            evaluate: when true, actions are selected with an epsilon of 0 and
                no training records are collected; otherwise ``self.eps`` is
                passed to ``select_actions``.

        Returns:
            A tuple ``(episode, reward)``.  ``episode`` is a list with one dict
            per step holding ``'q_target'``, ``'q_value'``, ``'r'`` and
            ``'done'`` (empty when ``evaluate`` is true).  ``reward`` is the
            sum over steps of ``sum(r[:, 0]) / args.vec_env``.

        Raises:
            NotImplementedError: if ``args.scenario_name`` is not supported.
        """
        s, available_actions = self.env.reset()
        step = 0
        episode = []
        reward = 0

        m = self.agents.init_message()
        hidden = self.agents.init_hidden()

        m_target = self.agents.init_message()
        hidden_target = self.agents.init_hidden()

        done = np.zeros(self.args.vec_env)

        # Stop at the step limit or as soon as every environment reports done.
        while step < self.episode_limit and not done.all():
            a_u = self._action_mask(available_actions)

            actions, q_value, hidden_next, message, m_next = self.agents.select_actions(
                s, m, a_u, hidden, 0 if evaluate else self.eps)

            u = self._env_action(actions, message)

            s_next, r, done, available_actions_next = self.env.step(u)

            reward += sum(r[:, 0]) / self.args.vec_env

            if not evaluate:
                with torch.no_grad():
                    # NOTE(review): the online message ``m`` is passed here while
                    # ``m_target`` is tracked but never read - confirm whether
                    # this is intended before changing it.
                    q_target, hidden_target_next, m_target_next = self.agents.compute_target(
                        s, m, a_u, hidden_target)

                episode.append({
                    'q_target': q_target,
                    'q_value': q_value,
                    'r': r[:, 0],
                    'done': done,
                })

                hidden_target = hidden_target_next
                m_target = m_target_next

            s = s_next
            available_actions = available_actions_next

            hidden = hidden_next
            m = m_next

            step += 1

        return episode, reward

    def _action_mask(self, available_actions):
        """Build the action mask for the configured scenario."""
        scenario = self.args.scenario_name
        if scenario == 'GuessingNumber':
            return self.guessing_number_mask(available_actions)
        if scenario == 'ReachingGoal':
            return self.revealing_goal_mask(available_actions)
        raise NotImplementedError

    def _env_action(self, actions, message):
        """Merge agent actions and messages into environment action ids."""
        scenario = self.args.scenario_name
        if scenario == 'GuessingNumber':
            return self.guessing_number_action(actions, message)
        if scenario == 'ReachingGoal':
            return self.revealing_goal_action(actions, message)
        raise NotImplementedError

    def guessing_number_mask(self, available_actions):
        """Build the per-agent action mask for the GuessingNumber scenario.

        For each environment and agent: if the last entry of the agent's
        availability vector equals 1, only the last action is allowed;
        otherwise the first ``action_shape - 1`` availability entries are
        copied and the last action is disallowed.

        Args:
            available_actions: indexable as ``[env_id][player_id][k]``.

        Returns:
            ``np.ndarray`` built from ``vec_env`` x ``n_agents`` masks of
            length ``action_shape``.
        """
        n_actions = self.args.action_shape
        results = []
        for env_id in range(self.args.vec_env):
            result = []
            for player_id in range(self.args.n_agents):
                available = available_actions[env_id][player_id]
                if available[-1] == 1:
                    action_mask = np.zeros(n_actions)
                    action_mask[-1] = 1
                else:
                    action_mask = np.ones(n_actions)
                    action_mask[-1] = 0
                    action_mask[:n_actions - 1] = available[:n_actions - 1]
                result.append(action_mask)
            results.append(result)
        return np.array(results)

    def guessing_number_action(self, actions, message):
        """Combine actions and one-hot messages into GuessingNumber env actions.

        The last action / message index is the shared "NOOP1" slot (both must
        be selected together), the second-to-last index is the "NOOP2" slot.
        An agent either takes a regular action (message must be NOOP2) or
        selects action NOOP2 and sends a message (message must not be NOOP2).

        Args:
            actions: indexable as ``[env_id][player_id][0]`` giving action ids.
            message: indexable as ``[env_id][player_id]`` giving one-hot vectors.

        Returns:
            ``np.ndarray`` of environment action ids per environment and agent.

        Raises:
            AssertionError: if action and message are inconsistent.
            NotImplementedError: if the action id is outside the known range.
            ValueError: from ``from_one_hot`` if a message is not one-hot.
        """
        n_actions = self.args.action_shape
        n_messages = self.args.message_shape
        results = []
        for env_id in range(self.args.vec_env):
            result = []
            for player_id in range(self.args.n_agents):
                action_id = actions[env_id][player_id][0]
                message_id = from_one_hot(message[env_id][player_id])
                if action_id == n_actions - 1 or message_id == n_messages - 1:
                    assert (action_id == n_actions - 1 and message_id == n_messages - 1), (
                        'action and reveal are not both NOOP1', action_id, n_actions - 1, message_id, n_messages - 1)
                    final_action = n_actions + n_messages - 4
                elif action_id == n_actions - 2:
                    assert (message_id != n_messages - 2), (
                        'action and reveal are both NOOP2', message_id, n_messages - 2)
                    # Message ids are placed after the regular action ids.
                    final_action = message_id + n_actions - 2
                elif action_id < n_actions - 2:
                    assert (message_id == n_messages - 2), (
                        'action working, reveal are not NOOP2', message_id, n_messages - 2)
                    final_action = action_id
                else:
                    raise NotImplementedError(action_id, message_id)
                result.append(final_action)
            results.append(result)
        return np.array(results)

    def revealing_goal_mask(self, available_actions):
        """Return an all-ones mask: every action is allowed in ReachingGoal.

        Args:
            available_actions: array whose first two dimensions give the
                leading shape of the mask; its contents are not read.

        Returns:
            ``np.ndarray`` of ones with shape
            ``(available_actions.shape[0], available_actions.shape[1], action_shape)``.
        """
        return np.ones([available_actions.shape[0], available_actions.shape[1], self.args.action_shape])

    def revealing_goal_action(self, actions, message):
        """Combine actions and one-hot messages into ReachingGoal env actions.

        The last action / message index is the "NOOP" slot.  An agent either
        takes a regular action (message must be NOOP) or selects action NOOP
        and sends a message (message must not be NOOP).

        Args:
            actions: indexable as ``[env_id][player_id][0]`` giving action ids.
            message: indexable as ``[env_id][player_id]`` giving one-hot vectors.

        Returns:
            ``np.ndarray`` of environment action ids per environment and agent.

        Raises:
            AssertionError: if action and message are inconsistent.
            NotImplementedError: if the action id is outside the known range.
            ValueError: from ``from_one_hot`` if a message is not one-hot.
        """
        n_actions = self.args.action_shape
        n_messages = self.args.message_shape
        results = []
        for env_id in range(self.args.vec_env):
            result = []
            for player_id in range(self.args.n_agents):
                action_id = actions[env_id][player_id][0]
                message_id = from_one_hot(message[env_id][player_id])
                if action_id == n_actions - 1:
                    # Report the index actually compared against (n_messages - 1).
                    assert (message_id != n_messages - 1), (
                        'action and reveal are both NOOP', message_id, n_messages - 1)
                    # Message ids are placed after the regular action ids.
                    final_action = message_id + n_actions - 1
                elif action_id < n_actions - 1:
                    assert (message_id == n_messages - 1), (
                        'action working, reveal are not NOOP', message_id, n_messages - 1)
                    final_action = action_id
                else:
                    raise NotImplementedError(action_id, message_id)
                result.append(final_action)
            results.append(result)
        return np.array(results)

def from_one_hot(one_hot):
    """Return the index of the single 1 in a one-hot vector.

    Args:
        one_hot: array-like (``np.ndarray``, list, tuple, ...) containing
            exactly one entry equal to 1 and zeros everywhere else.

    Returns:
        The index of the entry equal to 1, as returned by ``np.argmax``.

    Raises:
        ValueError: if the input is not a valid one-hot vector.
    """
    # Normalise first: comparing a plain list with ``== 1`` yields a scalar
    # False, which would make every list input look invalid.
    vector = np.asarray(one_hot)
    if np.sum(vector == 1) != 1 or np.sum(vector == 0) != len(vector) - 1:
        raise ValueError("input is not one-hot", one_hot)
    # Position of the element whose value is 1.
    return np.argmax(vector)